{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uPojRlNQPJub"
      },
      "source": [
        "Licensed under the Apache License, Version 2.0\n",
        "# Imports"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oOtnY2y8Om-1"
      },
      "source": [
        "Import the required libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:21:56.11266Z",
          "iopub.status.busy": "2021-12-04T23:21:56.111768Z",
          "iopub.status.idle": "2021-12-04T23:21:56.119068Z",
          "shell.execute_reply": "2021-12-04T23:21:56.117706Z",
          "shell.execute_reply.started": "2021-12-04T23:21:56.112602Z"
        },
        "id": "1Mf4kbp8yOxd",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "import re\n",
        "from typing import Dict, List, Optional, Text, Tuple\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import colors\n",
        "\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aVb2hhOcgVwU"
      },
      "source": [
        "# Load the dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "awWJ00JeOyuO"
      },
      "source": [
        "Enter the file pattern of the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:21:57.461966Z",
          "iopub.status.busy": "2021-12-04T23:21:57.461167Z",
          "iopub.status.idle": "2021-12-04T23:21:57.466727Z",
          "shell.execute_reply": "2021-12-04T23:21:57.465849Z",
          "shell.execute_reply.started": "2021-12-04T23:21:57.461917Z"
        },
        "id": "fhHMUbBoOg0k",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# Glob pattern for the training shards. `get_dataset` expands it with\n",
        "# `tf.data.Dataset.list_files` and reads every match as a TFRecord file.\n",
        "file_pattern = '/kaggle/input/next-day-wildfire-spread/next_day_wildfire_spread_train*'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OabUJMGqO9NE"
      },
      "source": [
        "Run the following three cells to define the required library functions for loading the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:21:58.421589Z",
          "iopub.status.busy": "2021-12-04T23:21:58.421087Z",
          "iopub.status.idle": "2021-12-04T23:21:58.434217Z",
          "shell.execute_reply": "2021-12-04T23:21:58.433505Z",
          "shell.execute_reply.started": "2021-12-04T23:21:58.421551Z"
        },
        "id": "GTTV3tjjCcdn",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# Constants for the data reader.\n",
        "\n",
        "# Names of the input channels, in the order `_parse_fn` stacks them into the\n",
        "# input image.\n",
        "INPUT_FEATURES = ['elevation', 'th', 'vs',  'tmmn', 'tmmx', 'sph',\n",
        "                  'pr', 'pdsi', 'NDVI', 'population', 'erc', 'PrevFireMask']\n",
        "\n",
        "# Names of the target channels.\n",
        "OUTPUT_FEATURES = ['FireMask', ]\n",
        "\n",
        "# Data statistics\n",
        "# For each variable, the statistics are ordered in the form:\n",
        "# (min_clip, max_clip, mean, std)\n",
        "# `_clip_and_rescale` reads only the two clip bounds; `_clip_and_normalize`\n",
        "# reads all four values.\n",
        "DATA_STATS = {\n",
        "    # 0.1 percentile, 99.9 percentile\n",
        "    'elevation': (0.0, 3141.0, 657.3003, 649.0147),\n",
        "    # Drought index (this channel is plotted under the title 'Drought').\n",
        "    # NOTE(review): this comment used to read 'Pressure'; 'pdsi' presumably\n",
        "    # stands for Palmer Drought Severity Index - confirm.\n",
        "    # 0.1 percentile, 99.9 percentile\n",
        "    'pdsi': (-6.1298, 7.8760, -0.0053, 2.6823),\n",
        "    # Vegetation index (this channel is plotted under the title 'Vegetation').\n",
        "    'NDVI': (-9821.0, 9996.0, 5157.625, 2466.6677),\n",
        "    # Precipitation in mm.\n",
        "    # Negative values make no sense, so min is set to 0.\n",
        "    # 0., 99.9 percentile\n",
        "    'pr': (0.0, 44.5304, 1.7398051, 4.4828),\n",
        "    # Specific humidity ranges from 0 to 100%.\n",
        "    'sph': (0., 1., 0.0071658953, 0.0042835088),\n",
        "    # Wind direction in degrees clockwise from north.\n",
        "    # Thus min set to 0 and max set to 360.\n",
        "    'th': (0., 360.0, 190.3298, 72.5985),\n",
        "    # Min/max temperature in Kelvin.\n",
        "    # -20 degree C, 99.9 percentile\n",
        "    'tmmn': (253.15, 298.9489, 281.08768, 8.9824),\n",
        "    # -20 degree C, 99.9 percentile\n",
        "    'tmmx': (253.15, 315.0923, 295.17383, 9.8155),\n",
        "    # Wind speed.\n",
        "    # Negative values do not make sense, given there is a wind direction.\n",
        "    # 0., 99.9 percentile\n",
        "    'vs': (0.0, 10.0243, 3.8501, 1.4110),\n",
        "    # NFDRS fire danger index energy release component expressed in BTU's per\n",
        "    # square foot.\n",
        "    # Negative values do not make sense. Thus min set to zero.\n",
        "    # 0., 99.9 percentile\n",
        "    'erc': (0.0, 106.2489, 37.3263, 20.8460),\n",
        "    # Population\n",
        "    # min, 99.9 percentile\n",
        "    'population': (0., 2534.0630, 25.5314, 154.7233),\n",
        "    # We don't want to normalize the FireMasks.\n",
        "    # With mean 0 and std 1, `_clip_and_normalize` leaves values inside [-1, 1]\n",
        "    # unchanged; `_clip_and_rescale` would still map [-1, 1] onto [0, 1].\n",
        "    'PrevFireMask': (-1., 1., 0., 1.),\n",
        "    'FireMask': (-1., 1., 0., 1.)\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:21:58.898032Z",
          "iopub.status.busy": "2021-12-04T23:21:58.897564Z",
          "iopub.status.idle": "2021-12-04T23:21:58.907504Z",
          "shell.execute_reply": "2021-12-04T23:21:58.9067Z",
          "shell.execute_reply.started": "2021-12-04T23:21:58.897995Z"
        },
        "id": "QqGYv21hD-2q",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "\"\"\"Library of common functions used in deep learning neural networks.\n",
        "\"\"\"\n",
        "def random_crop_input_and_output_images(\n",
        "    input_img: tf.Tensor,\n",
        "    output_img: tf.Tensor,\n",
        "    sample_size: int,\n",
        "    num_in_channels: int,\n",
        "    num_out_channels: int,\n",
        ") -\u003e Tuple[tf.Tensor, tf.Tensor]:\n",
        "  \"\"\"Takes one random square crop shared by the input and output images.\n",
        "\n",
        "  The two images are joined along the channel axis before cropping, so the\n",
        "  same randomly placed window is cut out of both of them.\n",
        "\n",
        "  Args:\n",
        "    input_img: Tensor with dimensions HWC.\n",
        "    output_img: Tensor with dimensions HWC.\n",
        "    sample_size: Side length (square) to crop to.\n",
        "    num_in_channels: Number of channels in `input_img`.\n",
        "    num_out_channels: Number of channels in `output_img`.\n",
        "\n",
        "  Returns:\n",
        "    An `(input_img, output_img)` tuple of the cropped tensors, both HWC.\n",
        "  \"\"\"\n",
        "  crop_shape = [sample_size, sample_size, num_in_channels + num_out_channels]\n",
        "  stacked = tf.image.random_crop(\n",
        "      tf.concat([input_img, output_img], axis=2), crop_shape)\n",
        "  # Leading channels belong to the input, trailing channels to the output.\n",
        "  cropped_input = stacked[:, :, :num_in_channels]\n",
        "  cropped_output = stacked[:, :, -num_out_channels:]\n",
        "  return cropped_input, cropped_output\n",
        "\n",
        "\n",
        "def center_crop_input_and_output_images(\n",
        "    input_img: tf.Tensor,\n",
        "    output_img: tf.Tensor,\n",
        "    sample_size: int,\n",
        ") -\u003e Tuple[tf.Tensor, tf.Tensor]:\n",
        "  \"\"\"Center-crops the input and output images with `tf.image.central_crop`.\n",
        "\n",
        "  The crop fraction is `sample_size` divided by the size of the first\n",
        "  dimension of `input_img`; the same fraction is applied to both images.\n",
        "\n",
        "  Args:\n",
        "    input_img: Tensor with dimensions HWC.\n",
        "    output_img: Tensor with dimensions HWC.\n",
        "    sample_size: Side length (square) to crop to.\n",
        "\n",
        "  Returns:\n",
        "    An `(input_img, output_img)` tuple of the cropped tensors, both HWC.\n",
        "  \"\"\"\n",
        "  fraction = sample_size / input_img.shape[0]\n",
        "  cropped_input = tf.image.central_crop(input_img, fraction)\n",
        "  cropped_output = tf.image.central_crop(output_img, fraction)\n",
        "  return cropped_input, cropped_output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:21:59.359837Z",
          "iopub.status.busy": "2021-12-04T23:21:59.359068Z",
          "iopub.status.idle": "2021-12-04T23:21:59.392567Z",
          "shell.execute_reply": "2021-12-04T23:21:59.391623Z",
          "shell.execute_reply.started": "2021-12-04T23:21:59.359778Z"
        },
        "id": "VBvI9FuGEC09",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "\"\"\"Dataset reader for Earth Engine data.\"\"\"\n",
        "\n",
        "def _get_base_key(key: Text) -\u003e Text:\n",
        "  \"\"\"Extracts the base key from the provided key.\n",
        "\n",
        "  Earth Engine exports TFRecords containing each data variable with its\n",
        "  corresponding variable name. In the case of time sequences, the name of the\n",
        "  data variable is of the form 'variable_1', 'variable_2', ..., 'variable_n',\n",
        "  where 'variable' is the name of the variable, and n the number of elements\n",
        "  in the time sequence. Extracting the base key ensures that each step of the\n",
        "  time sequence goes through the same normalization steps.\n",
        "  The base key obeys the following naming pattern: '[a-zA-Z]+'\n",
        "  For instance, for an input key 'variable_1', this function returns 'variable'.\n",
        "  For an input key 'variable', this function simply returns 'variable'.\n",
        "\n",
        "  Args:\n",
        "    key: Input key.\n",
        "\n",
        "  Returns:\n",
        "    The corresponding base key.\n",
        "\n",
        "  Raises:\n",
        "    ValueError when `key` does not match the expected pattern.\n",
        "  \"\"\"\n",
        "  match = re.match(r'[a-zA-Z]+', key)\n",
        "  if match:\n",
        "    return match.group(1)\n",
        "  raise ValueError(\n",
        "      f'The provided key does not match the expected pattern: {key}'\n",
        "\n",
        "\n",
        "def _clip_and_rescale(inputs: tf.Tensor, key: Text) -\u003e tf.Tensor:\n",
        "  \"\"\"Clips `inputs` to the clip range of `key` and maps that range onto [0, 1].\n",
        "\n",
        "  Args:\n",
        "    inputs: Inputs to clip and rescale.\n",
        "    key: Key describing the inputs; its base key selects the `DATA_STATS` entry.\n",
        "\n",
        "  Returns:\n",
        "    Clipped and rescaled input.\n",
        "\n",
        "  Raises:\n",
        "    ValueError if there are no data statistics available for `key`.\n",
        "  \"\"\"\n",
        "  stats = DATA_STATS.get(_get_base_key(key))\n",
        "  if stats is None:\n",
        "    raise ValueError(\n",
        "        'No data statistics available for the requested key: {}.'.format(key))\n",
        "  lower, upper, _, _ = stats\n",
        "  clipped = tf.clip_by_value(inputs, lower, upper)\n",
        "  # divide_no_nan yields 0 instead of NaN/Inf should the clip range be empty.\n",
        "  return tf.math.divide_no_nan(clipped - lower, upper - lower)\n",
        "\n",
        "\n",
        "def _clip_and_normalize(inputs: tf.Tensor, key: Text) -\u003e tf.Tensor:\n",
        "  \"\"\"Clips `inputs` to the clip range of `key`, then standardizes them.\n",
        "\n",
        "  After clipping, the mean recorded for `key` is subtracted and the result is\n",
        "  divided by the recorded standard deviation.\n",
        "\n",
        "  Args:\n",
        "    inputs: Inputs to clip and normalize.\n",
        "    key: Key describing the inputs; its base key selects the `DATA_STATS` entry.\n",
        "\n",
        "  Returns:\n",
        "    Clipped and normalized input.\n",
        "\n",
        "  Raises:\n",
        "    ValueError if there are no data statistics available for `key`.\n",
        "  \"\"\"\n",
        "  stats = DATA_STATS.get(_get_base_key(key))\n",
        "  if stats is None:\n",
        "    raise ValueError(\n",
        "        'No data statistics available for the requested key: {}.'.format(key))\n",
        "  lower, upper, mean, std = stats\n",
        "  centered = tf.clip_by_value(inputs, lower, upper) - mean\n",
        "  # divide_no_nan yields 0 instead of NaN/Inf should std be 0.\n",
        "  return tf.math.divide_no_nan(centered, std)\n",
        "\n",
        "def _get_features_dict(\n",
        "    sample_size: int,\n",
        "    features: List[Text],\n",
        ") -\u003e Dict[Text, tf.io.FixedLenFeature]:\n",
        "  \"\"\"Builds the parsing spec that maps each feature name to a float32 tile.\n",
        "\n",
        "  Duplicate names in `features` are collapsed; every remaining name is mapped\n",
        "  to a fixed-length float32 feature of shape [sample_size, sample_size].\n",
        "\n",
        "  Args:\n",
        "    sample_size: Size of the input tiles (square).\n",
        "    features: List of feature names.\n",
        "\n",
        "  Returns:\n",
        "    A features dictionary for TensorFlow IO.\n",
        "  \"\"\"\n",
        "  tile_shape = [sample_size, sample_size]\n",
        "  return {\n",
        "      name: tf.io.FixedLenFeature(shape=tile_shape, dtype=tf.float32)\n",
        "      for name in set(features)\n",
        "  }\n",
        "\n",
        "\n",
        "def _parse_fn(\n",
        "    example_proto: tf.train.Example, data_size: int, sample_size: int,\n",
        "    num_in_channels: int, clip_and_normalize: bool,\n",
        "    clip_and_rescale: bool, random_crop: bool, center_crop: bool,\n",
        ") -\u003e Tuple[tf.Tensor, tf.Tensor]:\n",
        "  \"\"\"Parses one serialized example into an (input, output) image pair.\n",
        "\n",
        "  Every feature named in `INPUT_FEATURES` and `OUTPUT_FEATURES` is parsed as a\n",
        "  float32 tile of shape [data_size, data_size]. The input tiles are optionally\n",
        "  clipped and normalized or rescaled; inputs and outputs are then each stacked\n",
        "  into a channels-last image and optionally cropped to `sample_size`.\n",
        "\n",
        "  Args:\n",
        "    example_proto: A TensorFlow example protobuf.\n",
        "    data_size: Size of tiles (square) as read from input files.\n",
        "    sample_size: Size of the tiles (square) when input into the model.\n",
        "    num_in_channels: Number of input channels.\n",
        "    clip_and_normalize: True if the data should be clipped and normalized.\n",
        "      Takes precedence over `clip_and_rescale` when both are True.\n",
        "    clip_and_rescale: True if the data should be clipped and rescaled.\n",
        "    random_crop: True if the data should be randomly cropped.\n",
        "    center_crop: True if the data should be cropped in the center.\n",
        "\n",
        "  Returns:\n",
        "    (input_img, output_img) tuple of inputs and outputs to the ML model.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If both `random_crop` and `center_crop` are True.\n",
        "  \"\"\"\n",
        "  if random_crop and center_crop:\n",
        "    raise ValueError('Cannot have both random_crop and center_crop be True')\n",
        "  input_features, output_features = INPUT_FEATURES, OUTPUT_FEATURES\n",
        "  feature_names = input_features + output_features\n",
        "  features_dict = _get_features_dict(data_size, feature_names)\n",
        "  features = tf.io.parse_single_example(example_proto, features_dict)\n",
        "\n",
        "  # Only the inputs are clipped/scaled; the outputs are passed through as read.\n",
        "  if clip_and_normalize:\n",
        "    inputs_list = [\n",
        "        _clip_and_normalize(features[key], key) for key in input_features\n",
        "    ]\n",
        "  elif clip_and_rescale:\n",
        "    inputs_list = [\n",
        "        _clip_and_rescale(features[key], key) for key in input_features\n",
        "    ]\n",
        "  else:\n",
        "    inputs_list = [features[key] for key in input_features]\n",
        "\n",
        "  # Stack as (channel, row, col), then move the channel axis last.\n",
        "  inputs_stacked = tf.stack(inputs_list, axis=0)\n",
        "  input_img = tf.transpose(inputs_stacked, [1, 2, 0])\n",
        "\n",
        "  outputs_list = [features[key] for key in output_features]\n",
        "  assert outputs_list, 'outputs_list should not be empty'\n",
        "  outputs_stacked = tf.stack(outputs_list, axis=0)\n",
        "\n",
        "  outputs_stacked_shape = outputs_stacked.get_shape().as_list()\n",
        "  assert len(outputs_stacked.shape) == 3, (\n",
        "      'outputs_stacked should be rank 3 '\n",
        "      'but dimensions of outputs_stacked '\n",
        "      f'are {outputs_stacked_shape}')\n",
        "  output_img = tf.transpose(outputs_stacked, [1, 2, 0])\n",
        "\n",
        "  if random_crop:\n",
        "    # The output channel count follows OUTPUT_FEATURES instead of a literal 1,\n",
        "    # so the crop stays correct if more output features are added.\n",
        "    input_img, output_img = random_crop_input_and_output_images(\n",
        "        input_img, output_img, sample_size, num_in_channels,\n",
        "        len(output_features))\n",
        "  if center_crop:\n",
        "    input_img, output_img = center_crop_input_and_output_images(\n",
        "        input_img, output_img, sample_size)\n",
        "  return input_img, output_img\n",
        "\n",
        "\n",
        "def get_dataset(file_pattern: Text, data_size: int, sample_size: int,\n",
        "                batch_size: int, num_in_channels: int, compression_type: Text,\n",
        "                clip_and_normalize: bool, clip_and_rescale: bool,\n",
        "                random_crop: bool, center_crop: bool) -\u003e tf.data.Dataset:\n",
        "  \"\"\"Builds a batched dataset of (input, output) images from TFRecord files.\n",
        "\n",
        "  Args:\n",
        "    file_pattern: Input file pattern.\n",
        "    data_size: Size of tiles (square) as read from input files.\n",
        "    sample_size: Size of the tiles (square) when input into the model.\n",
        "    batch_size: Batch size.\n",
        "    num_in_channels: Number of input channels.\n",
        "    compression_type: Type of compression used for the input files; forwarded\n",
        "      to `tf.data.TFRecordDataset`.\n",
        "    clip_and_normalize: True if the data should be clipped and normalized, False\n",
        "      otherwise.\n",
        "    clip_and_rescale: True if the data should be clipped and rescaled, False\n",
        "      otherwise.\n",
        "    random_crop: True if the data should be randomly cropped.\n",
        "    center_crop: True if the data should be cropped in the center.\n",
        "\n",
        "  Returns:\n",
        "    A TensorFlow dataset loaded from the input file pattern, with features\n",
        "    described in the constants, and with the shapes determined from the input\n",
        "    parameters to this function.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If both `clip_and_normalize` and `clip_and_rescale` are True.\n",
        "  \"\"\"\n",
        "  if clip_and_normalize and clip_and_rescale:\n",
        "    raise ValueError('Cannot have both normalize and rescale.')\n",
        "  autotune = tf.data.experimental.AUTOTUNE\n",
        "  # Read the records of all matching files, interleaving across files.\n",
        "  records = tf.data.Dataset.list_files(file_pattern).interleave(\n",
        "      lambda path: tf.data.TFRecordDataset(\n",
        "          path, compression_type=compression_type),\n",
        "      num_parallel_calls=autotune)\n",
        "  # Decode each serialized example into an (input_img, output_img) pair.\n",
        "  examples = records.prefetch(buffer_size=autotune).map(\n",
        "      lambda serialized: _parse_fn(  # pylint: disable=g-long-lambda\n",
        "          serialized, data_size, sample_size, num_in_channels,\n",
        "          clip_and_normalize, clip_and_rescale, random_crop, center_crop),\n",
        "      num_parallel_calls=autotune)\n",
        "  return examples.batch(batch_size).prefetch(buffer_size=autotune)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rdNDytnsPTVK"
      },
      "source": [
        "Load the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:22:00.260248Z",
          "iopub.status.busy": "2021-12-04T23:22:00.25974Z",
          "iopub.status.idle": "2021-12-04T23:22:00.852967Z",
          "shell.execute_reply": "2021-12-04T23:22:00.851977Z",
          "shell.execute_reply.started": "2021-12-04T23:22:00.260208Z"
        },
        "id": "X1jBBEinQbM0",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# Read the 64x64 tiles from the files and take a random 32x32 crop of each\n",
        "# example, in batches of 100. Values are neither clipped nor scaled.\n",
        "# `num_in_channels` matches the 12 entries of `INPUT_FEATURES`.\n",
        "dataset = get_dataset(\n",
        "      file_pattern,\n",
        "      data_size=64,\n",
        "      sample_size=32,\n",
        "      batch_size=100,\n",
        "      num_in_channels=12,\n",
        "      compression_type=None,\n",
        "      clip_and_normalize=False,\n",
        "      clip_and_rescale=False,\n",
        "      random_crop=True,\n",
        "      center_crop=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ca8USco0PdG3"
      },
      "source": [
        "TF Datasets are loaded lazily, so materialize the first batch of inputs and labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:22:01.317784Z",
          "iopub.status.busy": "2021-12-04T23:22:01.317477Z",
          "iopub.status.idle": "2021-12-04T23:22:01.591395Z",
          "shell.execute_reply": "2021-12-04T23:22:01.590278Z",
          "shell.execute_reply.started": "2021-12-04T23:22:01.317751Z"
        },
        "id": "Ml7Rg8aCQiTT",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# Pull a single batch eagerly. The plotting cells index `inputs` as\n",
        "# [example, row, column, feature] and `labels` as [example, row, column, 0].\n",
        "inputs, labels = next(iter(dataset))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzCgCoxtgP-f"
      },
      "source": [
        "# Plotting function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:22:02.409458Z",
          "iopub.status.busy": "2021-12-04T23:22:02.4082Z",
          "iopub.status.idle": "2021-12-04T23:22:02.41446Z",
          "shell.execute_reply": "2021-12-04T23:22:02.413495Z",
          "shell.execute_reply.started": "2021-12-04T23:22:02.409372Z"
        },
        "id": "bnG0_l_ChjUt",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# Column titles for the figure: one per entry of `INPUT_FEATURES`, in the same\n",
        "# order, followed by one for the output fire mask.\n",
        "TITLES = [\n",
        "  'Elevation',\n",
        "  'Wind\\ndirection',\n",
        "  'Wind\\nvelocity',\n",
        "  'Min\\ntemp',\n",
        "  'Max\\ntemp',\n",
        "  'Humidity',\n",
        "  'Precip',\n",
        "  'Drought',\n",
        "  'Vegetation',\n",
        "  'Population\\ndensity',\n",
        "  'Energy\\nrelease\\ncomponent',\n",
        "  'Previous\\nfire\\nmask',\n",
        "  'Fire\\nmask'\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:22:03.029066Z",
          "iopub.status.busy": "2021-12-04T23:22:03.028554Z",
          "iopub.status.idle": "2021-12-04T23:22:03.037678Z",
          "shell.execute_reply": "2021-12-04T23:22:03.036364Z",
          "shell.execute_reply.started": "2021-12-04T23:22:03.02903Z"
        },
        "id": "G0E6lWR9beD0",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# Number of examples from the batch to plot, one per figure row.\n",
        "n_rows = 5\n",
        "# Size of the last axis of `inputs`, i.e. the number of input feature maps.\n",
        "n_features = inputs.shape[3]\n",
        "# Three-colour scheme for the fire masks: values below -0.1 are drawn black,\n",
        "# values from -0.1 up to 0.001 silver, and larger values orangered.\n",
        "# NOTE(review): presumably -1 = unlabeled, 0 = no fire, 1 = fire - confirm.\n",
        "CMAP = colors.ListedColormap(['black', 'silver', 'orangered'])\n",
        "BOUNDS = [-1, -0.1, 0.001, 1]\n",
        "NORM = colors.BoundaryNorm(BOUNDS, CMAP.N)\n",
        "# NOTE(review): `keys` is not referenced by any other cell shown here.\n",
        "keys = INPUT_FEATURES"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2021-12-04T23:22:03.73577Z",
          "iopub.status.busy": "2021-12-04T23:22:03.734649Z",
          "iopub.status.idle": "2021-12-04T23:22:07.29126Z",
          "shell.execute_reply": "2021-12-04T23:22:07.290443Z",
          "shell.execute_reply.started": "2021-12-04T23:22:03.735694Z"
        },
        "id": "sPtKQzQv71J_",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "# One row per example; one column per input feature plus a final column for\n",
        "# the target fire mask.\n",
        "fig = plt.figure(figsize=(15, 6.5))\n",
        "n_cols = n_features + 1\n",
        "\n",
        "for row in range(n_rows):\n",
        "  for col in range(n_cols):\n",
        "    plt.subplot(n_rows, n_cols, row * n_cols + col + 1)\n",
        "    if row == 0:\n",
        "      # Titles only on the top row.\n",
        "      plt.title(TITLES[col], fontsize=13)\n",
        "    if col \u003c n_features - 1:\n",
        "      # Continuous input features.\n",
        "      plt.imshow(inputs[row, :, :, col], cmap='viridis')\n",
        "    elif col == n_features - 1:\n",
        "      # Last input channel is the previous fire mask: discrete colours.\n",
        "      plt.imshow(inputs[row, :, :, -1], cmap=CMAP, norm=NORM)\n",
        "    else:\n",
        "      # Final column: the target fire mask.\n",
        "      plt.imshow(labels[row, :, :, 0], cmap=CMAP, norm=NORM)\n",
        "    plt.axis('off')\n",
        "plt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "kaggle_next_day_wildfire_demo.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1l6l66eew8rD1LcR1nx7seYhXt194U905",
          "timestamp": 1639619194522
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
